# Toy batch of embeddings: 2 sequences x 2 tokens x 3 features.
# torch.tensor (lowercase) is the documented factory for building a tensor
# from Python data; the capitalized torch.Tensor constructor is legacy.
# An explicit float dtype is needed because LayerNorm does not accept
# integer tensors (the literals below are ints).
embedding = torch.tensor(
    [
        [[2, 1, 3],
         [3, 4, 5]],
        [[4, 5, 6],
         [5, 8, 9]],
    ],
    dtype=torch.float32,
)
print(embedding.size())  # torch.Size([2, 2, 3])

# Normalize each token's feature vector (the last dimension) to zero mean and
# unit (biased) variance. elementwise_affine=False drops the learnable
# gain/bias, so the output is the pure normalization.
# normalized_shape is taken from the tensor itself so it stays in sync if
# the feature size changes.
layer_norm = nn.LayerNorm(
    normalized_shape=embedding.size(-1),
    elementwise_affine=False,
)
s = layer_norm(embedding)
print(s)
